{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "jQg6hsMhbuod"
   },
   "source": [
    "### REQUIRED PIP INSTALLS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "AagX97p36_7z"
   },
   "outputs": [],
   "source": [
    "# %pip (unlike !pip) installs into the environment of the running kernel.\n",
    "# torch is listed because a later cell imports torch.utils.data.\n",
    "%pip install transformers kaggle datasets tqdm torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "K9sCPQzIb278"
   },
   "source": [
    "### DOWNLOAD THE DATASET"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Im-RfeRablcJ",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import kaggle\n",
    "from pathlib import Path\n",
    "\n",
    "\n",
    "def download_dataset_from_kaggle(path=\"data\"):\n",
    "    \"\"\"\n",
    "    Download and unzip the CodeSearchNet dataset from Kaggle.\n",
    "    Make sure to have the Kaggle API token in ~/.kaggle/kaggle.json\n",
    "\n",
    "    Args:\n",
    "        path (str, optional): Directory the dataset is downloaded into. Defaults to \"data\".\n",
    "\n",
    "    Returns:\n",
    "        str: Path to the downloaded dataset.\n",
    "    \"\"\"\n",
    "    path = Path(path)\n",
    "    # Create the download directory itself; creating only `path.parent` would\n",
    "    # leave the default \"data\" directory missing.\n",
    "    path.mkdir(parents=True, exist_ok=True)\n",
    "    kaggle.api.authenticate()\n",
    "    kaggle.api.dataset_download_files(\"omduggineni/codesearchnet\", path=path, unzip=True)\n",
    "    # Hand the location back to the caller, as the docstring promises.\n",
    "    return path.as_posix()\n",
    "\n",
    "\n",
    "download_dataset_from_kaggle()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tDzmSa0loVOC"
   },
   "source": [
    "### LOAD THE DATASET"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Z3GQ3tGYnM_z"
   },
   "outputs": [],
   "source": [
    "import glob\n",
    "\n",
    "from datasets import load_dataset\n",
    "from pathlib import Path\n",
    "\n",
    "LANG = \"python\"\n",
    "\n",
    "\n",
    "def load_local_dataset(lang=\"all\", path=\"data\"):\n",
    "    \"\"\"\n",
    "    Load the train/validation/test splits from the downloaded Kaggle dataset.\n",
    "\n",
    "    Args:\n",
    "        lang (str): Language sub-directory to load, or \"all\" to collect the\n",
    "            `.jsonl` files of every language found below `path`.\n",
    "        path (str, optional): Path to the downloaded dataset. Defaults to \"data\".\n",
    "\n",
    "    Returns:\n",
    "        Dataset: dataset loaded from local files, with the splits\n",
    "            \"train\", \"validation\" and \"test\".\n",
    "    \"\"\"\n",
    "    root = Path(path)\n",
    "\n",
    "    if lang != \"all\":\n",
    "        # Files of a single language live in <path>/<lang>/<lang>/final/jsonl/<split>/\n",
    "        pattern_prefix = (root / lang / lang / \"final/jsonl\").as_posix()\n",
    "        recursive = False\n",
    "    else:\n",
    "        # \"**\" together with recursive=True matches the split folders at any depth.\n",
    "        pattern_prefix = root.as_posix() + \"/**\"\n",
    "        recursive = True\n",
    "\n",
    "    # The folder on disk is named \"valid\" while the split is exposed as \"validation\".\n",
    "    split_dirs = {\"train\": \"train\", \"validation\": \"valid\", \"test\": \"test\"}\n",
    "\n",
    "    # glob returns files in arbitrary order; sorting in BOTH branches keeps the row\n",
    "    # order reproducible, which matters because later cells match rows by position.\n",
    "    data_files = {\n",
    "        split: sorted(glob.glob(f\"{pattern_prefix}/{split_dir}/*.jsonl\", recursive=recursive))\n",
    "        for split, split_dir in split_dirs.items()\n",
    "    }\n",
    "\n",
    "    return load_dataset(\"json\", data_files=data_files)\n",
    "\n",
    "\n",
    "dataset = load_local_dataset(LANG)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "iRzz39jDb8AT"
   },
   "source": [
    "### LOAD THE MODEL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "KkTFNYCi6-PO"
   },
   "outputs": [],
   "source": [
    "from transformers import RobertaTokenizer, T5ForConditionalGeneration\n",
    "\n",
    "# Tokenizer of the base CodeT5 checkpoint; truncation is applied on the right side.\n",
    "tokenizer = RobertaTokenizer.from_pretrained(\"Salesforce/codet5-base\", truncation_side=\"right\")\n",
    "# The checkpoint name is derived from LANG and the model is moved to the GPU (\"cuda\").\n",
    "# NOTE(review): presumably not every LANG value has a matching checkpoint -- confirm before changing LANG.\n",
    "model = T5ForConditionalGeneration.from_pretrained(f\"Salesforce/codet5-base-codexglue-sum-{LANG}\").to(\"cuda\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3MxfnNxX2n0m"
   },
   "source": [
    "### GENERATE THE SUMMARIES AND ANNOTATE THE DATASET"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "background_save": true
    },
    "id": "LtmFMCPVuf3j"
   },
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "from tqdm.auto import tqdm\n",
    "import json\n",
    "\n",
    "BATCH_SIZE = 32\n",
    "PARTITION = \"train\"\n",
    "# all_summaries = []\n",
    "\n",
    "\n",
    "def tokenization_collator(batch_sample):\n",
    "    \"\"\"Tokenize the `original_string` of every sample and return the padded input ids on the GPU.\"\"\"\n",
    "    sources = [sample[\"original_string\"] for sample in batch_sample]\n",
    "    encoded = tokenizer(sources, return_tensors=\"pt\", padding=\"longest\", truncation=True)\n",
    "    return encoded.input_ids.to(\"cuda\")\n",
    "\n",
    "\n",
    "# Write one JSON line per sample, in dataset order, to \"<partition>-<LANG>.jsonl\".\n",
    "for PARTITION in [\"train\", \"test\", \"validation\"]:\n",
    "    dataloader = DataLoader(dataset[PARTITION], batch_size=BATCH_SIZE, collate_fn=tokenization_collator)\n",
    "\n",
    "    with open(PARTITION + f\"-{LANG}.jsonl\", \"w\") as f:\n",
    "        # len(dataloader) is the exact number of batches, also when the partition\n",
    "        # size is a multiple of BATCH_SIZE.\n",
    "        for batch_data in tqdm(dataloader, total=len(dataloader)):\n",
    "            generated_ids = model.generate(batch_data, max_length=512)\n",
    "            # Keep the batch dimension: squeezing a batch that holds a single sample\n",
    "            # would yield a 1-D tensor and batch_decode would then emit one string per\n",
    "            # token, shifting every following line away from its dataset row.\n",
    "            summaries = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)\n",
    "\n",
    "            for summary in summaries:\n",
    "                f.write(json.dumps({\"summary\": summary}) + \"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "i--G3NMHBBsK"
   },
   "outputs": [],
   "source": [
    "import gzip, json\n",
    "\n",
    "# Samples whose generated summary has more characters than this are dropped.\n",
    "MAX_SUMMARY_CHARS = 350\n",
    "\n",
    "# Merge every partition with its generated summaries into one gzip-compressed JSON-lines file.\n",
    "with gzip.open(f\"dataset_{LANG}.jsonl.gz\", \"w\") as w:\n",
    "    for PARTITION in [\"train\", \"test\", \"validation\"]:\n",
    "        rows = dataset[PARTITION]\n",
    "        num_summaries = 0\n",
    "\n",
    "        with open(PARTITION + f\"-{LANG}.jsonl\") as f:\n",
    "            # Line i of the summaries file belongs to row i of the partition.\n",
    "            # Iterating the file object streams it instead of loading every line at once.\n",
    "            for line_pos, line in enumerate(f):\n",
    "                num_summaries += 1\n",
    "                summary = json.loads(line)\n",
    "                if len(summary[\"summary\"]) > MAX_SUMMARY_CHARS:\n",
    "                    continue\n",
    "\n",
    "                d = dict(rows[line_pos])\n",
    "                d.update(summary)\n",
    "                w.write((json.dumps(d) + \"\\n\").encode(\"UTF-8\"))\n",
    "\n",
    "        # A differing line count means summaries and rows are no longer aligned.\n",
    "        if num_summaries != len(rows):\n",
    "            raise ValueError(\n",
    "                f\"{PARTITION}: found {num_summaries} summaries for {len(rows)} dataset rows\"\n",
    "            )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Qc7ANkjIBBsL"
   },
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6-7KPky6BBsM"
   },
   "outputs": [],
   "source": [
    "# Reload the gzip-compressed JSON-lines file named after LANG from the working directory.\n",
    "# `load_dataset` is the name imported in the \"LOAD THE DATASET\" cell.\n",
    "processed_dataset = load_dataset(\"json\", data_files=f\"./dataset_{LANG}.jsonl.gz\")\n",
    "# Upload to the Hub repository named after LANG.\n",
    "# NOTE(review): presumably the account used in notebook_login() needs write access to this repository -- confirm.\n",
    "processed_dataset.push_to_hub(f\"Nan-Do/code-search-net-{LANG}\")"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [
    "jQg6hsMhbuod",
    "K9sCPQzIb278",
    "tDzmSa0loVOC",
    "iRzz39jDb8AT"
   ],
   "machine_shape": "hm",
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
